Introduction

Hierarchical modeling is often useful because data are frequently organized as hierarchies or in multiple levels of aggregation. However, such models can be difficult to implement and analyze, both because of their inherent complexity and because of a possible lack of sufficient data. Traditionally, statistical software had to be written from scratch to fit hierarchical models (especially if Bayesian inference was employed), which made it difficult to do substantial exploratory work with these models.

Recently, a number of software packages have emerged that automate many of the difficult aspects of Bayesian hierarchical modeling. In particular, the specification and implementation of Markov chain Monte Carlo samplers have been automated to the point where we no longer have to spend copious amounts of time deriving and implementing the conditional distributions of Gibbs samplers. This lecture will explore one such package, greta, for doing these computations.

Motivating Example

We will return to the analysis of PM10 and mortality data and examine data from 20 large cities in the United States. For data pertaining to a single city, a typical time series regression model will look as follows. We use a Poisson regression to model the count outcome (daily numbers of deaths) and use a log-linear predictor containing splines of temperature (tmpd), date, and PM10.
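In symbols, the single-city model for the daily death count \(Y_t\) is

\[\log \mathbb{E}[Y_t] = \alpha + s_1(\text{tmpd}_t) + s_2(\text{date}_t) + \beta\,\text{PM10}_t,\]

where \(Y_t\) is assumed Poisson, \(s_1\) and \(s_2\) are natural spline terms (with 3 and \(8\times 19 = 152\) degrees of freedom, respectively), and \(\beta\) is the log-relative risk associated with PM10.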

library(splines)
library(dplyr)
library(broom)

dat <- readRDS("data/nmmaps/ny.rds") %>%
        select(death, tmpd, date, pm10tmean)

fit <- glm(death ~ ns(tmpd, 3) + ns(date, 8 * 19) + pm10tmean,
           data = dat, family = stats::poisson)
tidy(fit) %>%
        filter(term == "pm10tmean")
# A tibble: 1 x 5
  term      estimate std.error statistic  p.value
  <chr>        <dbl>     <dbl>     <dbl>    <dbl>
1 pm10tmean 0.000862  0.000241      3.57 0.000355

The coefficient for pm10tmean is the primary target of interest. Here we would interpret it as a 0.87% increase in mortality for a 10 unit increase in PM10.
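That interpretation comes from exponentiating the coefficient; a quick sketch of the arithmetic:

```r
## Convert the log-relative risk to a percent increase in mortality
## per 10-unit increase in PM10: (exp(10 * beta) - 1) * 100
beta_hat <- 0.000862
pct <- (exp(10 * beta_hat) - 1) * 100
round(pct, 2)  ## approximately 0.87
```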

With data from 20 cities, we can fit this same model independently to each city’s data and get a sense of what the coefficients for PM10 look like (the log-relative risks).

infiles <- dir("data/nmmaps", glob2rx("*.rds"), full.names = TRUE)
fit.f <- lapply(infiles, function(file) {
        dat <- readRDS(file)
        glm(death ~ ns(tmpd, 3) + ns(date, 8 * 19) + pm10tmean, 
            data = dat, family = stats::poisson)
})
names(fit.f) <- sub(".rds", "", basename(infiles), fixed = TRUE)
results <- lapply(fit.f, tidy) %>%
        bind_rows(.id = "city") %>%
        filter(term == "pm10tmean")

The results are shown here for each city.

results
# A tibble: 20 x 6
   city  term        estimate std.error statistic     p.value
   <chr> <chr>          <dbl>     <dbl>     <dbl>       <dbl>
 1 atla  pm10tmean -0.000875  0.000564    -1.55   0.121      
 2 aust  pm10tmean -0.000931  0.00125     -0.747  0.455      
 3 balt  pm10tmean -0.0000117 0.000363    -0.0322 0.974      
 4 bost  pm10tmean  0.000865  0.000749     1.15   0.248      
 5 chic  pm10tmean  0.000390  0.0000768    5.07   0.000000388
 6 clev  pm10tmean  0.000258  0.000126     2.05   0.0401     
 7 denv  pm10tmean  0.000250  0.000227     1.10   0.270      
 8 det   pm10tmean  0.000324  0.000125     2.58   0.00982    
 9 dlft  pm10tmean  0.000926  0.000324     2.86   0.00423    
10 hous  pm10tmean  0.000709  0.000206     3.45   0.000560   
11 la    pm10tmean  0.000408  0.000155     2.63   0.00850    
12 miam  pm10tmean -0.000559  0.000537    -1.04   0.298      
13 minn  pm10tmean  0.0000729 0.000217     0.335  0.737      
14 no    pm10tmean  0.000685  0.000927     0.739  0.460      
15 ny    pm10tmean  0.000862  0.000241     3.57   0.000355   
16 phoe  pm10tmean  0.000452  0.000246     1.84   0.0657     
17 pitt  pm10tmean  0.000493  0.000120     4.13   0.0000368  
18 ral   pm10tmean -0.00171   0.00149     -1.15   0.249      
19 seat  pm10tmean -0.000247  0.000213    -1.16   0.245      
20 stlo  pm10tmean  0.000852  0.000966     0.882  0.378      

Our primary goal here is to combine the log-relative risk estimates to obtain a single “overall” risk estimate that summarizes the data from all 20 cities. To do this, we will take a two-stage approach where we first compute risk estimates for each city independently and then combine them in a second stage using a Normal hierarchical model.

But first, the greta package.

Using the greta Package

The greta package, written by Nick Golding, is designed for fitting complex (often Bayesian hierarchical) models in R. It is similar in spirit to WinBUGS or JAGS, and more recently Stan, but it has a number of high-level and low-level differences:

  • Its syntax is R based, unlike Stan and WinBUGS, which each define their own model specification language.

  • The underlying implementation uses Google TensorFlow, which allows users to take immediate advantage of TensorFlow capabilities where available (GPU/TPU computation, parallelization).

Single-city GLM with greta

As an example of how to use the greta package we will start by fitting a Bayesian version of the single-city model that we fit to the New York City data above. Here, we will use the same log-linear Poisson model, but will add prior distributions to the regression model parameters.

First we can load the greta package and read in the data for New York City.

library(greta)

dat <- readRDS("data/nmmaps/ny.rds") %>%
        select(death, tmpd, date, pm10tmean)

Then we need to create the model matrix (design matrix) and the outcome vector. To do this we use the same model formula that was used in the GLM in the previous section.

mm <- model.matrix(death ~ ns(tmpd, 3) + ns(date, 8 * 19) + pm10tmean,
                   data = dat)
y <- dat$death[!is.na(dat$pm10tmean)]

Now that we have the data, we need to specify the parameters. For this model, the parameters are a vector of coefficients for the regression model. For a Bayesian formulation, we need to specify the prior distribution for them. Here we will use a \(\mathcal{N}(0, 10^2)\) distribution as the prior (independently) for all the coefficients.

beta <- normal(0, 10, dim = ncol(mm))

We can print out the beta object to see what the greta package does here.

head(beta)
greta array (operation)

     [,1]
[1,]  ?  
[2,]  ?  
[3,]  ?  
[4,]  ?  
[5,]  ?  
[6,]  ?  

Essentially, this is a \(157\times 1\) array of unknown parameters: one intercept, 3 temperature spline terms, \(8\times 19 = 152\) date spline terms, and the PM10 coefficient.

After setting the prior distributions, we need to specify the log-linear predictor for the Poisson model.

log.mu <- mm %*% beta
mu <- exp(log.mu)
head(mu)
greta array (operation)

     [,1]
[1,]  ?  
[2,]  ?  
[3,]  ?  
[4,]  ?  
[5,]  ?  
[6,]  ?  

Finally, we can specify the probability distribution for the data y as coming from a Poisson distribution.

distribution(y) <- greta::poisson(mu)

Note that we use the full function name greta::poisson() here because there are multiple functions named poisson() in different packages and we do not want any confusion.

Once all of this is specified we need to create a model object with the model() function. Here, we need to pass any arrays containing unknown parameters (i.e. the random elements in a Bayesian formulation). In this example, that is just the vector beta.

mod <- model(beta)

Before doing any model fitting, it can be useful to plot a graphical representation of the model to make sure that everything was properly specified.

plot(mod)

Once we have confirmed that the model is properly specified, we can fit it using Markov chain Monte Carlo to sample from the posterior distribution of beta.

In this invocation of the mcmc() function from greta, we specify:

  • The model object mod

  • The sampler to use, here Hamiltonian Monte Carlo (the alternative is a random walk Metropolis-Hastings sampler)

  • We should draw 1,000 samples from the chain (after a 1,000 iteration warmup period)

  • We should only sample a single chain (the default is 4)

The output from mcmc() by default gives progress on the warmup and the sampling.

r <- greta::mcmc(mod, 
                 sampler = hmc(), 
                 n_samples = 1000, 
                 chains = 1)

    warmup                                           0/1000 | eta:  ?s          
    warmup ==                                       50/1000 | eta: 16s | 20% bad
    ...
    warmup ====================================== 1000/1000 | eta:  0s | 1% bad 

  sampling                                           0/1000 | eta:  ?s          
  sampling ==                                       50/1000 | eta:  9s          
  ...
  sampling ====================================== 1000/1000 | eta:  0s          

The execution of the Hamiltonian Monte Carlo sampler is done through Google TensorFlow which allows us to take advantage of the parallelization built into TensorFlow (primarily for matrix/array computations). On a 2016 MacBook Pro, the sampling process uses 3 processors.

The object returned by greta::mcmc() can be fed into functions from the bayesplot package. Here we will plot the trace plot of the pm10tmean variable (which is the very last, hence 157th, coefficient).

library(bayesplot)
mcmc_trace(r, "beta[157,1]")

We can also compute the posterior mean and standard deviation.

beta.m <- as.matrix(r)
mean(beta.m[, 157])
[1] 0.001838654
sd(beta.m[, 157])
[1] 0.00020908

Note that the posterior mean here is quite a bit larger than the maximum likelihood estimate shown in the previous section. That said, it is likely that we have not run our sampler long enough, as 1,000 iterations is a very small number for almost any MCMC sampler.
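To see the discrepancy on a more interpretable scale, we can convert both point estimates shown above to a percent increase in mortality per 10-unit increase in PM10:

```r
## Percent increase in mortality per 10-unit increase in PM10,
## using the MLE and the posterior mean reported above
mle  <- 0.000862
post <- 0.001838654
pct  <- function(b) (exp(10 * b) - 1) * 100
pct(mle)   ## approximately 0.87
pct(post)  ## approximately 1.86
```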

Normal Approximations

If we want to use a two-stage approach to combine the data from the 20 cities into a single overall risk estimate, we need to check whether the profile likelihood for the PM10 coefficient in the Poisson model is reasonably well-approximated by a Normal distribution centered around its maximum likelihood estimate.

We can first compute the profile log-likelihood for the pm10tmean coefficient numerically.

profileLL <- function(x) {
        form <- reformulate(c("ns(tmpd, 3)", "ns(date, 8*19)",
                              sprintf("offset(I(%f * pm10tmean))", x)),
                            response = "death")
        fit <- glm(form, data = dat, family = stats::poisson)
        logLik(fit)
}
profileLL <- Vectorize(profileLL)
x <- seq(0, 0.002, len = 100)
p <- profileLL(x)

Then we can plot the profile likelihood function.

library(ggplot2)
ggplot(mapping = aes(x, exp(p - max(p)))) + 
        geom_line() +
        xlab(expression(beta)) + 
        ylab("Profile Likelihood") + 
        theme_bw()

Since the profile likelihood looks very close to a Normal distribution, it seems reasonable that we can use the two-stage model here without having to use the full Poisson likelihood for each city.
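One way to visualize this check is to overlay the Normal (quadratic) approximation, centered at the MLE with spread given by the estimated standard error, on the profile likelihood. A sketch, using the New York estimates from the results table above:

```r
## Normal approximation to the (normalized) profile likelihood
## for the NY pm10tmean coefficient
betahat <- 0.000862   ## MLE from the first-stage GLM
sehat   <- 0.000241   ## estimated standard error
x <- seq(0, 0.002, len = 100)
normal.approx <- exp(-(x - betahat)^2 / (2 * sehat^2))
## This can be overlaid on the profile likelihood plot with, e.g.,
##   geom_line(aes(x, normal.approx), linetype = "dashed")
```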

Two-Stage Model

We can specify a two-stage Normal hierarchical model as follows.

\[\begin{eqnarray*}
\hat{\beta}_c \mid \beta_c & \sim & \mathcal{N}(\beta_c, \hat{\sigma}_c^2)\\
\beta_c \mid \mu, \tau & \sim & \mathcal{N}(\mu, \tau^2)\\
\mu & \sim & \mathcal{N}(0, 1)\\
\tau & \sim & \text{Unif}(0, 0.001)
\end{eqnarray*}\]

Here, the \(\hat{\beta}_c\) and \(\hat{\sigma}_c^2\) are obtained from the first stage GLM fit and are the maximum likelihood estimates. The parameter \(\mu\) is our overall log-relative risk estimate for the 20 cities and \(\tau\) is the “natural heterogeneity” in risk (i.e. unexplained by statistical variation) across the 20 cities.

Since we already fit each of the single-city models in the first section above, we can simply recall the results here.

results
# A tibble: 20 x 6
   city  term        estimate std.error statistic     p.value
   <chr> <chr>          <dbl>     <dbl>     <dbl>       <dbl>
 1 atla  pm10tmean -0.000875  0.000564    -1.55   0.121      
 2 aust  pm10tmean -0.000931  0.00125     -0.747  0.455      
 3 balt  pm10tmean -0.0000117 0.000363    -0.0322 0.974      
 4 bost  pm10tmean  0.000865  0.000749     1.15   0.248      
 5 chic  pm10tmean  0.000390  0.0000768    5.07   0.000000388
 6 clev  pm10tmean  0.000258  0.000126     2.05   0.0401     
 7 denv  pm10tmean  0.000250  0.000227     1.10   0.270      
 8 det   pm10tmean  0.000324  0.000125     2.58   0.00982    
 9 dlft  pm10tmean  0.000926  0.000324     2.86   0.00423    
10 hous  pm10tmean  0.000709  0.000206     3.45   0.000560   
11 la    pm10tmean  0.000408  0.000155     2.63   0.00850    
12 miam  pm10tmean -0.000559  0.000537    -1.04   0.298      
13 minn  pm10tmean  0.0000729 0.000217     0.335  0.737      
14 no    pm10tmean  0.000685  0.000927     0.739  0.460      
15 ny    pm10tmean  0.000862  0.000241     3.57   0.000355   
16 phoe  pm10tmean  0.000452  0.000246     1.84   0.0657     
17 pitt  pm10tmean  0.000493  0.000120     4.13   0.0000368  
18 ral   pm10tmean -0.00171   0.00149     -1.15   0.249      
19 seat  pm10tmean -0.000247  0.000213    -1.16   0.245      
20 stlo  pm10tmean  0.000852  0.000966     0.882  0.378      

From this we will need the vector of estimates and standard errors for the hierarchical model.

betahat <- results$estimate
sdhat <- results$std.error
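Before fitting the Bayesian model, a simple fixed-effects (inverse-variance weighted) pooled estimate provides a useful point of comparison; it is what the hierarchical model collapses to when \(\tau = 0\). The values below are transcribed from the results table above.

```r
## Fixed-effects pooled estimate: weight each city's estimate by 1 / SE^2
## (estimates and standard errors transcribed from the results table)
betahat <- c(-0.000875, -0.000931, -0.0000117, 0.000865, 0.000390,
              0.000258,  0.000250,  0.000324,  0.000926, 0.000709,
              0.000408, -0.000559,  0.0000729, 0.000685, 0.000862,
              0.000452,  0.000493, -0.00171,  -0.000247, 0.000852)
sdhat <- c(0.000564, 0.00125,  0.000363, 0.000749, 0.0000768,
           0.000126, 0.000227, 0.000125, 0.000324, 0.000206,
           0.000155, 0.000537, 0.000217, 0.000927, 0.000241,
           0.000246, 0.000120, 0.00149,  0.000213, 0.000966)
w <- 1 / sdhat^2
pooled    <- sum(w * betahat) / sum(w)
pooled.se <- sqrt(1 / sum(w))
```

Comparing pooled with the posterior mean of \(\mu\) from the hierarchical fit gives a sense of how much the random-effects structure changes the overall estimate.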

Now we need to use the greta functions to specify our hierarchical model. First, we will give the prior distributions for \(\mu\) and \(\tau\) as Normal and uniform, respectively.

mu <- normal(0, 1)
tau <- uniform(0, 0.001)

Then we will specify \(\beta_c\) as having a Normal distribution with mean \(\mu\) and standard deviation \(\tau\) for every city.

betac <- normal(mu, tau, dim = length(betahat))

Finally, we will specify that the \(\hat{\beta}_c\)s also follow a Normal distribution. In this example, the \(\hat{\beta}_c\)s are the “data/outcome” for the model.

distribution(betahat) <- normal(betac, sdhat)

Now we can create our model object and plot it to make sure it’s properly specified.

mod <- model(betac, mu, tau)
plot(mod)

After verifying that the model is properly specified, we can run the mcmc() functions to draw samples from the posterior distribution. We will use similar settings as before, except now we will draw 5,000 samples.

r <- greta::mcmc(mod, 
                 sampler = hmc(), 
                 n_samples = 5000, 
                 chains = 1)

    warmup                                           0/1000 | eta:  ?s          
    warmup ==                                       50/1000 | eta: 21s          
    ...
    warmup ====================================== 1000/1000 | eta:  0s          

  sampling                                           0/5000 | eta:  ?s          
  sampling                                          50/5000 | eta: 30s          
  ...
  sampling ====================================== 5000/5000 | eta:  0s          

The primary parameter of interest here is the \(\mu\) parameter, which is the overall risk estimate. We can draw a trace plot of \(\mu\) and see how its samples look.

mcmc_trace(r, "mu")

We can also look at all of the other parameters in the model by drawing 95% credible intervals.

mcmc_intervals(r, prob_outer = 0.95)

We can see here that the individual city-specific estimates are generally positive, with the exception of a few cities. Also, the 95% credible interval for the overall risk estimate does not cover 0, providing strong evidence for a positive association between PM10 and mortality, on average across the 20 cities.

Summary

  • Bayesian hierarchical models provide a powerful class of tools for data analysis.

  • These kinds of models have traditionally been of limited use because of the difficulty of implementing and fitting them.

  • The greta package provides a way to fit hierarchical models using the R language for model specification (rather than requiring that a new language be learned).

  • Many aspects of MCMC are better understood today, allowing for more automation of the process in packages like greta.